import gym
import numpy as np
import torch
import math
from scipy.special import expit
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from sklearn.metrics import precision_score, recall_score, f1_score
from gym.envs.registration import register
from model.trainer.score.CQR import CQR
import warnings
# Silence DeprecationWarnings emitted from gym's passive environment checker.
warnings.filterwarnings(
    "ignore",
    category=DeprecationWarning,
    module="gym.utils.passive_env_checker",
)

# Register CRL in gym's environment registry under the id "ts-regression-v0".
register(id="ts-regression-v0", entry_point="model.environment.conformal_rl:CRL")

class CRL(gym.Env):
    """Gym environment for one-step regression with adaptive conformal intervals.

    Each step the agent submits a prediction for the last column of ``data``.
    The environment:

    1. scores the prediction;
    2. updates an adaptive miscoverage level ``alpha_t`` and a conformal
       threshold ``q_hat`` from the most recent ``calibration_size`` scores;
    3. builds a prediction interval;
    4. returns a reward combining ``tanh(1 - |error|)``, a +/-1 term for
       agreement between "true value outside interval" and "prediction
       outside interval", and an interval-width penalty normalised by the
       mean of recent targets.

    Args:
        data: 2-D array of shape ``(seq_len, feature_dim)``. Its last column
            is the regression target.
        window_size: Number of rows in each observation.
        output_dim: Length of the action vector in ``action_space``.
        alpha: Target miscoverage level. Also the initial ``alpha_t``.
        calibration_size: Maximum number of recent scores and targets used for
            calibration and reward normalisation.
        device: Torch device for score and interval tensors.
        score_function: ``'CQR'`` to use the project's ``CQR`` score (built on
            ``reset``), or ``None`` to score with the absolute residual and
            build intervals from a residual quantile.
        gamma: Step size of the ``alpha_t`` update.
        seed: Seed passed to ``gym.Env.reset`` when ``reset`` gets none.
        **kwargs: Ignored.
    """

    def __init__(self, data, window_size=10, output_dim=1, alpha=0.05,
                 calibration_size=50, device="cuda",
                 score_function=None, gamma=0.005,
                 seed=None, **kwargs):
        super(CRL, self).__init__()
        self.data = data
        self.window_size = window_size
        self.output_dim = output_dim
        self.alpha = alpha
        self.seq_len, self.feature_dim = data.shape
        self.step_count = 0
        self.seed = seed
        self.device = device
        self.calibration_size = calibration_size
        # Assigned unconditionally: reset(), step(), calculate_score() and
        # compute_confidence_interval() read these even without a score function.
        self.score_function = score_function
        self.alpha_t = alpha
        self.gamma = gamma
        self.q_hat = None

        self.observation_space = gym.spaces.Box(low=-np.inf, high=np.inf,
                                                shape=(window_size, self.feature_dim),
                                                dtype=np.float32)

        self.action_space = gym.spaces.Box(low=-np.inf, high=np.inf,
                                           shape=(self.output_dim,), dtype=np.float32)

        self._reset_episode_state()

    def _reset_episode_state(self):
        """Rewind to the first window and empty every per-episode history buffer."""
        self.current_index = self.window_size
        self.residuals = []
        self.scores = []
        self.prediction_interval_buffer, self.ground_truth_buffer = [], []
        self.precision_buffer, self.recall_buffer, self.f1_buffer = [], [], []

    def reset(self, seed=None, options=None):
        """Start a new episode at the beginning of ``data``.

        Args:
            seed: Seed for ``gym.Env.reset``. Falls back to the constructor's ``seed``.
            options: Accepted for gym API compatibility. Unused.

        Returns:
            ``(state, info)``: the first ``window_size`` rows as float32, and
            a dict of zero-initialised diagnostics.
        """
        super().reset(seed=self.seed if seed is None else seed)
        # Clear histories so calibration scores / intervals do not carry over
        # between episodes and the buffers do not grow without bound.
        self._reset_episode_state()
        self.alpha_t = self.alpha
        self.q_hat = None

        # The string 'CQR' is swapped for a CQR instance on the first reset.
        if isinstance(self.score_function, str) and self.score_function == "CQR":
            self.score_function = CQR()

        info = {}
        info.update({"mse": 0.0, "mae": 0.0})
        info.update({"alpha_t": 0.0, "q_hat": 0.0})
        info.update({"interval_len": 0.0, 'upper': 0.0, 'lower': 0.0, 'precision': 0.0, 'ground_truth': 0.0})

        # NOTE(review): this window ends at row current_index - 1, but the
        # windows returned by step() end at row current_index. It also
        # contains the target that the first step() scores against. Confirm
        # both are intended.
        state = self.data[:self.window_size, :].astype(np.float32)

        return state, info

    def step(self, action):
        """Score ``action`` against the current target and advance one row.

        Args:
            action: Array prediction. ``action[:1]`` is compared with the target;
                the whole array is flattened to shape (1, -1) as score input.

        Returns:
            ``(next_state, reward, done, truncated, info)``. ``done`` and
            ``truncated`` both become True once the next window would run past
            the end of ``data``.
        """
        assert self.current_index >= self.window_size, "current_index > window_size"
        true_value = np.array([self.data[self.current_index - 1, -1]])
        pred_value = action[:1]

        done = False
        truncated = False
        info = {}

        mse = mean_squared_error(true_value, pred_value)
        mae = mean_absolute_error(true_value, pred_value)
        error = np.abs(true_value - pred_value)

        # Weighted miscoverage of the previous interval. 0 until one exists.
        if self.prediction_interval_buffer:
            error_t = self.calculate_error_rate(
                x_batch=self.data[self.current_index - self.window_size:self.current_index - 1, -1],
                y_batch_last=self.ground_truth_buffer[-1],
                pred_interval_last=self.prediction_interval_buffer[-1],
            ).item()
        else:
            error_t = 0

        pred_ = torch.Tensor(action).reshape(1, -1).to(self.device)
        truth_ = torch.Tensor(true_value).unsqueeze(-1).to(self.device)

        self.scores.append(self.calculate_score(pred_, truth_).item())
        self.alpha_t, self.q_hat = self.update_conformal_value(error_t)

        lower, upper = self.compute_confidence_interval(predicts_batch=action, q_hat=self.q_hat)
        # np.tanh, not torch.tanh: `error` is a NumPy array, which torch.tanh rejects.
        reward = np.tanh(1 - error)

        true_anomaly = (true_value < lower) | (true_value > upper)
        model_anomaly = (pred_value < lower) | (pred_value > upper)

        # +1 when both fall outside the interval, -1 when exactly one does.
        if true_anomaly & model_anomaly:
            reward += 1.0
        elif true_anomaly & ~model_anomaly:
            reward -= 1.0
        elif ~true_anomaly & model_anomaly:
            reward -= 1.0

        self.residuals.append(error)
        self.ground_truth_buffer.append(true_value)
        # Stored as [lower, upper]; calculate_error_rate depends on this order.
        self.prediction_interval_buffer.append(
            torch.stack([torch.as_tensor(lower), torch.as_tensor(upper)], dim=0).to(self.device))

        precision = precision_score(true_anomaly, model_anomaly, zero_division=1)
        recall = recall_score(true_anomaly, model_anomaly, zero_division=1)
        f1 = f1_score(true_anomaly, model_anomaly, zero_division=1)
        interval_len = min(np.mean(upper - lower), 10000.0)

        # Width penalty, normalised by the mean of the most recent targets.
        norm_value = np.mean(self.ground_truth_buffer[-self.calibration_size:])
        reward += 1 - interval_len / (norm_value + 1e-6)

        info.update({"precision": precision, "recall": recall, "f1": f1})
        info.update({"mse": mse, "mae": mae})
        info.update({"ground_truth": true_value.item()})
        info.update({"interval_len": interval_len, 'upper': upper, 'lower': lower})
        # float() handles both a 1-element tensor and the float torch.inf
        # returned when calibration is impossible.
        info.update({"alpha_t": self.alpha_t, "q_hat": float(self.q_hat)})
        self.step_count += 1

        next_index = self.current_index + 1
        if next_index < self.seq_len:
            # NOTE(review): this window ends at row next_index, one row later
            # than the window layout used by reset(). Confirm which is intended.
            next_state = self.data[next_index - self.window_size + 1: next_index + 1, :]
        else:
            next_state = np.zeros((self.window_size, self.feature_dim), dtype=np.float32)
            done = True
            truncated = True

        self.current_index = next_index
        info = {k: np.clip(v, -1e8, 1e8) for k, v in info.items()}

        if done:
            info["final_info"] = {
                "mse": info["mse"],
                "mae": info["mae"],
            }

        return next_state.astype(np.float32), np.mean(reward).item(), bool(done), bool(truncated), info

    def update_conformal_value(self, error_t):
        """Update ``alpha_t`` from the latest miscoverage rate and recompute ``q_hat``.

        ``alpha_t <- alpha_t + gamma * (alpha - error_t)``, clamped to
        ``[1 / (size + 1), 0.9999]``, where ``size`` is the number of
        calibration scores used.

        Returns:
            ``(alpha_t, q_hat)``.
        """
        size = min(self.calibration_size, len(self.scores))
        alpha_t = max(1 / (size + 1),
                      min(0.9999, self.alpha_t + self.gamma * (self.alpha - error_t)))
        if len(self.scores) < self.calibration_size:
            q_hat = self.calculate_conformal_value(self.scores, alpha_t)
        else:
            q_hat = self.calculate_conformal_value(self.scores[-self.calibration_size:], alpha_t)

        return alpha_t, q_hat

    def calculate_score(self, predicts, y_truth):
        """Return the non-conformity score of ``predicts`` against ``y_truth``.

        Without a score function, falls back to the absolute residual of the
        first prediction column. This matches the ``error`` computed in ``step``.
        """
        if self.score_function is None:
            return torch.abs(predicts[..., :1] - y_truth)
        return self.score_function(predicts, y_truth)

    def calculate_conformal_value(self, scores, alpha):
        """Return the finite-sample-corrected ``(1 - alpha)`` quantile of ``scores``.

        Returns ``torch.inf`` (with a warning) when ``scores`` is empty, or
        when the required rank exceeds the number of scores. Otherwise returns
        a 1-element tensor on ``self.device``.

        Raises:
            ValueError: if ``alpha`` is not strictly between 0 and 1.
        """
        if alpha >= 1 or alpha <= 0:
            raise ValueError("Significance level 'alpha' must be in (0,1).")
        if len(scores) == 0:
            warnings.warn(
                f"The number of scores is 0, which is a invalid scores. To avoid program crash, the threshold is set as {torch.inf}.")
            return torch.inf
        N = len(scores)
        # Compute the integer rank directly. The small tolerance stops float
        # noise (e.g. at the alpha_t floor 1/(N+1)) from bumping it past N.
        rank = max(1, math.ceil((N + 1) * (1 - alpha) - 1e-9))
        if rank > N:
            warnings.warn(
                f"The value of quantile exceeds 1. It should be a value in [0,1]. To avoid program crash, the threshold is set as {torch.inf}.")
            return torch.inf

        scores_ = torch.Tensor(scores).to(self.device)
        return torch.kthvalue(scores_, rank, dim=0).values.view(1).to(self.device)

    def calculate_error_rate(self, x_batch, y_batch_last, pred_interval_last):
        """Return the weighted rate at which targets fell outside their intervals.

        Args:
            x_batch: Unused; kept for interface compatibility.
            y_batch_last: Array or tensor of recent targets.
            pred_interval_last: Tensor whose last axis is ``[lower, upper]``.

        Returns:
            A scalar tensor. Weights decay by 0.95 per position and are
            normalised, so the final entry has the largest weight.
        """
        y_last = torch.as_tensor(y_batch_last, device=self.device)
        steps_t = len(y_last)
        w_s = torch.pow(0.95, (steps_t - torch.arange(steps_t)).to(self.device))
        w_s = (w_s / torch.sum(w_s)).unsqueeze(1)

        lower = pred_interval_last[..., 0]
        upper = pred_interval_last[..., 1]
        # Miscovered when the target is at or beyond either bound.
        err = ((y_last >= upper) | (y_last <= lower)).int()
        return torch.sum(w_s * err)

    def compute_confidence_interval(self, predicts_batch=None, q_hat=None):
        """Return ``(lower, upper)`` bounds for the current step.

        With a score function, delegates to its ``generate_intervals``.
        Without one:

        - returns ``(-10.0, 10.0)`` until 10 residuals have been collected;
        - afterwards returns the last observed target +/- the ``1 - alpha``
          quantile of the most recent ``calibration_size`` residuals.
        """
        if self.score_function is not None:
            interval = self.score_function.generate_intervals(predicts_batch, q_hat).squeeze()
            return interval[0].item(), interval[1].item()

        if len(self.residuals) < 10:
            return -10.0, 10.0

        calibration_errors = np.array(self.residuals)[-self.calibration_size:]
        error_quantile = np.quantile(calibration_errors, 1 - self.alpha)

        # NOTE(review): this interval is centred on the same target that
        # step() compares against, not on the prediction. Confirm intended.
        last_actual = self.data[self.current_index - 1, -1]
        return last_actual - error_quantile, last_actual + error_quantile


if __name__ == "__main__":
    # Smoke test: random data and random actions, printing per-step diagnostics.
    data = np.random.randn(100000, 21)
    # CRL defaults to device="cuda"; fall back to CPU so the demo runs anywhere.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    env = CRL(data, device=device)
    env.reset()

    for _ in range(100):
        action = np.random.randn(10, 1)
        next_state, reward, done, truncated, info = env.step(action)
        print(next_state.shape, reward, done, info)
        if done or truncated:
            break

